Contributor

@dcaballe dcaballe commented Feb 4, 2025

#124863 added folding support for poison indices to vector.shuffle. This PR adds support for folding vector.shuffle ops with one or two poison input vectors.

Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

#124863 added folding support for poison indices to vector.shuffle. This PR adds support for folding vector.shuffle ops with one or two poison input vectors.


Full diff: https://github.com/llvm/llvm-project/pull/125608.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/IR/VectorOps.cpp (+32-11)
  • (modified) mlir/test/Dialect/Vector/canonicalize.mlir (+39)
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 93f89eda2da5a6..8d5691f38f273c 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -26,7 +26,6 @@
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinAttributes.h"
-#include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectImplementation.h"
 #include "mlir/IR/IRMapping.h"
@@ -42,7 +41,6 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
-#include "llvm/ADT/bit.h"
 
 #include <cassert>
 #include <cstdint>
@@ -2696,25 +2694,45 @@ OpFoldResult vector::ShuffleOp::fold(FoldAdaptor adaptor) {
   if (!v1Attr || !v2Attr)
     return {};
 
+  // Fold shuffle poison, poison -> poison.
+  bool isV1Poison = isa<ub::PoisonAttr>(v1Attr);
+  bool isV2Poison = isa<ub::PoisonAttr>(v2Attr);
+  if (isV1Poison && isV2Poison)
+    return ub::PoisonAttr::get(getContext());
+
   // Only support 1-D for now to avoid complicated n-D DenseElementsAttr
   // manipulation.
   if (v1Type.getRank() != 1)
     return {};
 
-  int64_t v1Size = v1Type.getDimSize(0);
+  // Poison input attributes need special handling as they are not
+  // DenseElementsAttr. If an index is poison, we select the first element of
+  // the first non-poison input.
+  SmallVector<Attribute> v1Elements, v2Elements;
+  Attribute poisonElement;
+  if (!isV2Poison) {
+    v2Elements =
+        to_vector(cast<DenseElementsAttr>(v2Attr).getValues<Attribute>());
+    poisonElement = v2Elements[0];
+  }
+  if (!isV1Poison) {
+    v1Elements =
+        to_vector(cast<DenseElementsAttr>(v1Attr).getValues<Attribute>());
+    poisonElement = v1Elements[0];
+  }
 
   SmallVector<Attribute> results;
-  auto v1Elements = cast<DenseElementsAttr>(v1Attr).getValues<Attribute>();
-  auto v2Elements = cast<DenseElementsAttr>(v2Attr).getValues<Attribute>();
+  int64_t v1Size = v1Type.getDimSize(0);
   for (int64_t maskIdx : mask) {
     Attribute indexedElm;
-    // Select v1[0] for poison indices.
     // TODO: Return a partial poison vector when supported by the UB dialect.
     if (maskIdx == ShuffleOp::kPoisonIndex) {
-      indexedElm = v1Elements[0];
+      indexedElm = poisonElement;
     } else {
-      indexedElm =
-          maskIdx < v1Size ? v1Elements[maskIdx] : v2Elements[maskIdx - v1Size];
+      if (maskIdx < v1Size)
+        indexedElm = isV1Poison ? poisonElement : v1Elements[maskIdx];
+      else
+        indexedElm = isV2Poison ? poisonElement : v2Elements[maskIdx - v1Size];
     }
 
     results.push_back(indexedElm);
@@ -3332,13 +3350,15 @@ class InsertStridedSliceConstantFolder final
         !destVector.hasOneUse())
       return failure();
 
-    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
-
     TypedValue<VectorType> sourceValue = op.getSource();
     Attribute sourceCst;
     if (!matchPattern(sourceValue, m_Constant(&sourceCst)))
       return failure();
 
+    // TODO: Support poison.
+    if (isa<ub::PoisonAttr>(vectorDestCst) || isa<ub::PoisonAttr>(sourceCst))
+      return failure();
+
     // TODO: Handle non-unit strides when they become available.
     if (op.hasNonUnitStrides())
       return failure();
@@ -3355,6 +3375,7 @@ class InsertStridedSliceConstantFolder final
     // increasing linearized position indices.
     // Because the destination may have higher dimensionality then the slice,
     // we keep track of two overlapping sets of positions and offsets.
+    auto denseDest = llvm::cast<DenseElementsAttr>(vectorDestCst);
     auto denseSlice = llvm::cast<DenseElementsAttr>(sourceCst);
     auto sliceValuesIt = denseSlice.value_begin<Attribute>();
     auto newValues = llvm::to_vector(denseDest.getValues<Attribute>());
diff --git a/mlir/test/Dialect/Vector/canonicalize.mlir b/mlir/test/Dialect/Vector/canonicalize.mlir
index 6858f0d56e6412..65c3ab264283d2 100644
--- a/mlir/test/Dialect/Vector/canonicalize.mlir
+++ b/mlir/test/Dialect/Vector/canonicalize.mlir
@@ -2023,6 +2023,45 @@ func.func @shuffle_1d_poison_idx() -> vector<4xi32> {
 
 // -----
 
+// CHECK-LABEL: func @shuffle_1d_rhs_lhs_poison
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[V:.+]] = ub.poison : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_rhs_lhs_poison() -> vector<4xi32> {
+  %v0 = ub.poison : vector<3xi32>
+  %v1 = ub.poison : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_lhs_poison
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[V:.+]] = arith.constant dense<[5, 4, 5, 5]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_lhs_poison() -> vector<4xi32> {
+  %v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
+  %v1 = ub.poison : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
+// CHECK-LABEL: func @shuffle_1d_rhs_poison
+//   CHECK-NOT:   vector.shuffle
+//       CHECK:   %[[V:.+]] = arith.constant dense<[2, 2, 0, 1]> : vector<4xi32>
+//       CHECK:   return %[[V]]
+func.func @shuffle_1d_rhs_poison() -> vector<4xi32> {
+  %v0 = ub.poison : vector<3xi32>
+  %v1 = arith.constant dense<[2, 1, 0]> : vector<3xi32>
+  %shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+  return %shuffle : vector<4xi32>
+}
+
+// -----
+
 // CHECK-LABEL: func @shuffle_canonicalize_0d
 func.func @shuffle_canonicalize_0d(%v0 : vector<i32>, %v1 : vector<i32>) -> vector<1xi32> {
   // CHECK: vector.broadcast %{{.*}} : vector<i32> to vector<1xi32>

We recently added folding support for poison indices to `vector.shuffle`.
This PR adds support for folding poison inputs.
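
As a rough illustration of the folding rule in the diff above, here is a standalone C++ sketch that models it outside of MLIR. Everything in it is an assumption made for illustration: `Operand`, `foldShuffle`, and the use of `std::nullopt` for a fully poison operand are hypothetical stand-ins, and the sentinel value `-1` for `kPoisonIndex` is assumed rather than taken from the patch. It is not the code in `VectorOps.cpp`.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-in: a constant 1-D operand is a std::vector, and
// std::nullopt models an operand that is entirely poison.
using Operand = std::optional<std::vector<int32_t>>;

// Assumed sentinel for a poison mask index (plays the role of
// ShuffleOp::kPoisonIndex in the patch).
constexpr int64_t kPoisonIndex = -1;

// Returns std::nullopt when both operands are poison (the whole result is
// poison); otherwise returns the folded constant vector.
std::optional<std::vector<int32_t>>
foldShuffle(const Operand &v1, const Operand &v2, int64_t v1Size,
            const std::vector<int64_t> &mask) {
  if (!v1 && !v2)
    return std::nullopt;

  // Stand-in value for every poison lane: the first element of the first
  // non-poison operand.
  const std::vector<int32_t> &firstConst = v1 ? *v1 : *v2;
  assert(!firstConst.empty() && "constant operand must have an element");
  const int32_t poisonElement = firstConst[0];

  std::vector<int32_t> result;
  result.reserve(mask.size());
  for (int64_t idx : mask) {
    if (idx == kPoisonIndex)
      result.push_back(poisonElement);
    else if (idx < v1Size)
      result.push_back(v1 ? (*v1)[idx] : poisonElement);
    else
      result.push_back(v2 ? (*v2)[idx - v1Size] : poisonElement);
  }
  return result;
}
```

With the mask `[3, 1, 5, 4]` used in the new tests, a constant first operand `[5, 4, 3]` and a poison second operand give `[5, 4, 5, 5]`, while a poison first operand and a constant second operand `[2, 1, 0]` give `[2, 2, 0, 1]`, matching the `CHECK` lines in `canonicalize.mlir`.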
@dcaballe dcaballe force-pushed the input-poison-shuffle-canon branch from ece1c6d to 46a4887 Compare February 4, 2025 19:34
Contributor

@banach-space banach-space left a comment

LGTM, thanks!

Left a couple of optional nits.

Comment on lines 2044 to 2046
%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
%v1 = ub.poison : vector<3xi32>
%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
Contributor

[nit] There's both a value and an index equal to 5, so it's not obvious that the first element of %v0 is in any way significant. Perhaps use a more distinct number (e.g. 123)?

Suggested change
-%v0 = arith.constant dense<[5, 4, 3]> : vector<3xi32>
-%v1 = ub.poison : vector<3xi32>
-%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>
+%v0 = arith.constant dense<[123, 4, 3]> : vector<3xi32>
+%v1 = ub.poison : vector<3xi32>
+%shuffle = vector.shuffle %v0, %v1 [3, 1, 5, 4] : vector<3xi32>, vector<3xi32>

I appreciate that this is obvious right now, but let's also cater for our future selves :)

Comment on lines +2696 to +2698
// Poison input attributes need special handling as they are not
// DenseElementsAttr. If an index is poison, we select the first element of
// the first non-poison input.
Contributor

[nit] To me this is a fairly significant (and not immediately intuitive) part of the design. Perhaps move above the signature?

Also, is this based on some prior-art? Just curious, this does make sense.

Contributor Author

Not sure I follow the prior-art part. Do you mean why we pick the first element of the first non-poison input? Poison is basically UB, so given that we can't represent a partially poison vector, we just make a random decision, which is OK as part of the UB behavior.

Member

It's valid to substitute poison with an arbitrary value

Contributor

> It's valid to substitute poison with an arbitrary value

Sure, but we are selecting a specific "arbitrary value" :)

> Not sure I follow the prior-art part.

I was just curious whether there's any rationale behind this specific option. For example, does something else in LLVM or MLIR make a similar choice?

Basically, what I'm missing is "why would we select the first element?" Something along these lines would be helpful:

> It doesn't matter what we select, but we need to make a choice. We choose the first element.

Member

@kuhar kuhar Feb 5, 2025

> Sure, but we are selecting a specific "arbitrary value" :)

??? In this context, arbitrary is synonymous with non-deterministic, as in: absolutely any value will do and the choice doesn't have to be fair by any definition of fair.
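
To make the "any value will do" point concrete, here is a small standalone C++ sketch of the refinement check implied by this thread. It assumes a hypothetical per-lane model in which a lane is either a concrete value or poison (`std::nullopt`); `Lane` and `refines` are illustrative names, not MLIR APIs. A folded constant is acceptable as long as every non-poison lane keeps its value, whatever was chosen for the poison lanes.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical model: a lane is either a concrete value or poison (nullopt).
using Lane = std::optional<int32_t>;

// Returns true if `folded` is an acceptable replacement for `ideal`:
// every defined lane must be preserved, while a poison lane may be
// replaced by any value at all.
bool refines(const std::vector<Lane> &ideal,
             const std::vector<int32_t> &folded) {
  if (ideal.size() != folded.size())
    return false;
  for (std::size_t i = 0; i < ideal.size(); ++i)
    if (ideal[i].has_value() && *ideal[i] != folded[i])
      return false;
  return true;
}
```

For the `[5, 4, 3]` / poison test case with mask `[3, 1, 5, 4]`, the per-lane result under this model is `{poison, 4, poison, poison}`. Both `[5, 4, 5, 5]` (what the fold produces) and, say, `[0, 4, 7, -1]` pass the check, whereas `[5, 3, 5, 5]` does not because it changes a defined lane.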

Contributor Author

dcaballe commented Feb 4, 2025

It looks like GitHub has been in the "Processing updates" stage (see top of the page) for almost an hour... Weird...

@dcaballe dcaballe merged commit c6eef00 into llvm:main Feb 5, 2025
6 of 7 checks passed